{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### LightGBM(gbdt)模型测试（tfidf特征）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import lightgbm as lgb\n",
    "from lightgbm.sklearn import LGBMClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>feat_1_tfidf</th>\n",
       "      <th>feat_2_tfidf</th>\n",
       "      <th>feat_3_tfidf</th>\n",
       "      <th>feat_4_tfidf</th>\n",
       "      <th>feat_5_tfidf</th>\n",
       "      <th>feat_6_tfidf</th>\n",
       "      <th>feat_7_tfidf</th>\n",
       "      <th>feat_8_tfidf</th>\n",
       "      <th>feat_9_tfidf</th>\n",
       "      <th>...</th>\n",
       "      <th>feat_84_tfidf</th>\n",
       "      <th>feat_85_tfidf</th>\n",
       "      <th>feat_86_tfidf</th>\n",
       "      <th>feat_87_tfidf</th>\n",
       "      <th>feat_88_tfidf</th>\n",
       "      <th>feat_89_tfidf</th>\n",
       "      <th>feat_90_tfidf</th>\n",
       "      <th>feat_91_tfidf</th>\n",
       "      <th>feat_92_tfidf</th>\n",
       "      <th>feat_93_tfidf</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.421803</td>\n",
       "      <td>0.052224</td>\n",
       "      <td>0.842245</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.068759</td>\n",
       "      <td>0.094897</td>\n",
       "      <td>0.502059</td>\n",
       "      <td>0.535424</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.143963</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.070171</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.071486</td>\n",
       "      <td>0.648348</td>\n",
       "      <td>0.050417</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.078248</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.071995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.048481</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.139311</td>\n",
       "      <td>0.034257</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0.048446</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.047156</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.065018</td>\n",
       "      <td>0.081172</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.556178</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 94 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  feat_1_tfidf  feat_2_tfidf  feat_3_tfidf  feat_4_tfidf  feat_5_tfidf  \\\n",
       "0   1      0.000000      0.000000      0.000000      0.000000           0.0   \n",
       "1   2      0.068759      0.094897      0.502059      0.535424           0.0   \n",
       "2   3      0.000000      0.071486      0.648348      0.050417           0.0   \n",
       "3   4      0.000000      0.000000      0.000000      0.048481           0.0   \n",
       "4   5      0.048446      0.000000      0.000000      0.047156           0.0   \n",
       "\n",
       "   feat_6_tfidf  feat_7_tfidf  feat_8_tfidf  feat_9_tfidf      ...        \\\n",
       "0           0.0      0.000000      0.000000           0.0      ...         \n",
       "1           0.0      0.000000      0.000000           0.0      ...         \n",
       "2           0.0      0.000000      0.000000           0.0      ...         \n",
       "3           0.0      0.000000      0.000000           0.0      ...         \n",
       "4           0.0      0.065018      0.081172           0.0      ...         \n",
       "\n",
       "   feat_84_tfidf  feat_85_tfidf  feat_86_tfidf  feat_87_tfidf  feat_88_tfidf  \\\n",
       "0            0.0       0.000000       0.421803       0.052224       0.842245   \n",
       "1            0.0       0.000000       0.000000       0.000000       0.000000   \n",
       "2            0.0       0.000000       0.000000       0.000000       0.078248   \n",
       "3            0.0       0.139311       0.034257       0.000000       0.000000   \n",
       "4            0.0       0.000000       0.000000       0.000000       0.000000   \n",
       "\n",
       "   feat_89_tfidf  feat_90_tfidf  feat_91_tfidf  feat_92_tfidf  feat_93_tfidf  \n",
       "0       0.000000            0.0       0.000000       0.000000       0.000000  \n",
       "1       0.143963            0.0       0.000000       0.070171       0.000000  \n",
       "2       0.000000            0.0       0.000000       0.000000       0.071995  \n",
       "3       0.000000            0.0       0.000000       0.000000       0.000000  \n",
       "4       0.000000            0.0       0.556178       0.000000       0.000000  \n",
       "\n",
       "[5 rows x 94 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dpath = \"../data/\"\n",
    "test = pd.read_csv(dpath + \"Otto_FE_test_tfidf.csv\")\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix\n",
    "test_id = test[\"id\"]\n",
    "X_test = test.iloc[:,1:94]\n",
    "X_test = csr_matrix(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "lgbc = pickle.load(open(\"otto_lgbmgbdt_tfidf.pkl\",\"rb\"))\n",
    "y_test_pred = lgbc.predict_proba(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(144368, 9)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test_pred.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "提交结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_df = pd.DataFrame(y_test_pred)\n",
    "\n",
    "columns = np.empty(9, dtype=object)\n",
    "for i in range(9):\n",
    "    columns[i] = 'Class_' + str(i+1)\n",
    "\n",
    "out_df.columns = columns\n",
    "\n",
    "out_df = pd.concat([test_id,out_df], axis = 1)\n",
    "out_df.to_csv(\"LightGBM_org_tfidf.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
